{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2f67f995",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:06:58.026548Z",
     "start_time": "2023-11-28T13:06:58.020728Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "405edb19",
   "metadata": {},
   "source": [
    "## 理解神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d5d8c6d",
   "metadata": {},
   "source": [
    "- 从shape的含义，shape 的变化，可能是理解神经网络模型结构乃至处理过程的一个很实用的切入。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0448091c",
   "metadata": {},
   "source": [
    "## `nn.utils.clip_grad_norm_`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c054c26",
   "metadata": {},
   "source": [
    "- 避免梯度爆炸\n",
    "   \n",
    "    - 计算所有梯度的范数：这一步骤涉及到计算网络中所有参数的梯度范数。通常使用 L2 范数，即平方和的平方根。\n",
    "        - 也是 `nn.utils.clip_grad_norm_` 的返回值；\n",
    "    - 比较梯度范数与最大值：将计算出的梯度范数与设定的最大范数值比较。\n",
    "    - 裁剪梯度：如果梯度范数大于最大值，那么将所有梯度缩放到最大值以内。这是通过乘以一个缩放因子实现的，缩放因子是最大范数值除以梯度范数。\n",
    "    - 更新梯度：使用裁剪后的梯度值更新网络参数。\n",
    "\n",
    "- 一般training过程中用到：\n",
    "\n",
    "    ```\n",
    "    optimizer.zero_grad()\n",
    "    # 反向传播计算 parameters 的 grad\n",
    "    loss.backward()\n",
    "    # 计算完梯度之后，norm 所有参数的 grads\n",
    "    nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)\n",
    "    # 基于更新后的 grad 值来更新参数\n",
    "    optimizer.step()\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c6a0eb31",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:11:08.362707Z",
     "start_time": "2023-11-28T13:11:08.352229Z"
    }
   },
   "outputs": [],
   "source": [
    "# 定义一个简单的网络及其梯度\n",
    "class SimpleNet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SimpleNet, self).__init__()\n",
    "        self.param1 = torch.nn.Parameter(torch.tensor([3.0, 4.0]))\n",
    "        self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))\n",
    "\n",
    "# 创建网络实例\n",
    "net = SimpleNet()\n",
    "\n",
    "# 设置梯度\n",
    "net.param1.grad = torch.tensor([3.0, 4.0])\n",
    "net.param2.grad = torch.tensor([1.0, 2.0])\n",
    "\n",
    "# 最大梯度范数\n",
    "max_norm = 5.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c7b9ad30",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:11:10.028971Z",
     "start_time": "2023-11-28T13:11:10.016953Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(5.4772), [tensor([2.7386, 3.6515]), tensor([0.9129, 1.8257])])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算梯度范数\n",
    "total_norm = torch.sqrt(sum(p.grad.norm()**2 for p in net.parameters()))\n",
    "\n",
    "# 计算缩放因子\n",
    "scale = max_norm / (total_norm + 1e-6)\n",
    "\n",
    "# 应用梯度裁剪\n",
    "for p in net.parameters():\n",
    "    p.grad.data.mul_(scale)\n",
    "\n",
    "# 更新后的梯度\n",
    "clipped_grads = [p.grad for p in net.parameters()]\n",
    "\n",
    "total_norm, clipped_grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d7143022",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:11:11.532229Z",
     "start_time": "2023-11-28T13:11:11.522275Z"
    }
   },
   "outputs": [],
   "source": [
    "# 定义一个简单的网络及其梯度\n",
    "class SimpleNet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SimpleNet, self).__init__()\n",
    "        self.param1 = torch.nn.Parameter(torch.tensor([3.0, 4.0]))\n",
    "        self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))\n",
    "\n",
    "# 创建网络实例\n",
    "net = SimpleNet()\n",
    "\n",
    "# 设置梯度\n",
    "net.param1.grad = torch.tensor([3.0, 4.0])\n",
    "net.param2.grad = torch.tensor([1.0, 2.0])\n",
    "\n",
    "# 最大梯度范数\n",
    "max_norm = 5.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "15d074eb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:11:14.160926Z",
     "start_time": "2023-11-28T13:11:14.151739Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(5.4772)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.utils.clip_grad_norm_(net.parameters(), max_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1a1bd7a7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-28T13:11:15.181539Z",
     "start_time": "2023-11-28T13:11:15.173888Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2.7386, 3.6515])\n",
      "tensor([0.9129, 1.8257])\n"
     ]
    }
   ],
   "source": [
    "for p in net.parameters():\n",
    "    print(p.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fd5152",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
